{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tqrD7Yzlmlsk"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2k8X1C1nmpKv"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "32xflLc4NTx-"
      },
      "source": [
        "# カスタムフェデレーテッドアルゴリズム、パート 2: フェデレーテッドアベレージングの実装"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jtATV6DlqPs0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/custom_federated_algorithms_2\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.orgで表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/federated/tutorials/custom_federated_algorithms_2.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Run in Google Colab</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/federated/tutorials/custom_federated_algorithms_2.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_igJ2sfaNWS8"
      },
      "source": [
        "このチュートリアルは、[フェデレーテッドコア（FC）](../federated_core.md)を使用して TFF でフェデレーテッドアルゴリズムのカスタムタイプを実装する方法を示す 2 部較正シリーズのうちの第 2 部です。これは、[フェデレーテッド学習（FL）](../federated_learning.md)レイヤー（`tff.learning`）の基礎として機能します。\n",
        "\n",
        "まずは[このシリーズの第 1 部](custom_federated_algorithms_1.ipynb)を読むことをお勧めします。第 1 部では、ここで使用される主な概念とプログラミング抽出が説明されています。\n",
        "\n",
        "シリーズの第 2 部では、第 1 部で導入したメカニズムを使用して、単純なバージョンのフェデレーテッドトレーニングと評価アルゴリズムを実装します。\n",
        "\n",
        "TFF のフェデレーテッド学習 API に関するより概要レベルのわかりやすい説明については、[画像分類器](federated_learning_for_image_classification.ipynb)と[テキスト生成](federated_learning_for_text_generation.ipynb)のチュートリアルをご覧ください。ここで説明する概念を文脈的に当てはめる上で役に立ちます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cuJuLEh2TfZG"
      },
      "source": [
        "## 始める前に\n",
        "\n",
        "始める前に、以下の「Hello World」の例を実行して、環境が正しく設定されていることを確認してください。動作しない場合は、[インストール](../install.md)ガイドをご覧ください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rB1ovcX1mBxQ"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow_federated_nightly\n",
        "!pip install --quiet --upgrade nest_asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-skNC6aovM46"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "# TODO(b/148678573,b/148685415): must use the reference context because it\n",
        "# supports unbounded references and tff.sequence_* intrinsics.\n",
        "tff.backends.reference.set_reference_context()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zzXwGnZamIMM"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'Hello, World!'"
            ]
          },
          "execution_count": 4,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "@tff.federated_computation\n",
        "def hello_world():\n",
        "  return 'Hello, World!'\n",
        "\n",
        "hello_world()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iu5Gd8D6W33s"
      },
      "source": [
        "## フェデレーテッドアベレージングを実装する\n",
        "\n",
        "[画像分類器用のフェデレーテッド学習](federated_learning_for_image_classification.ipynb)のように MNIST サンプルを使用しますが、これは低レベルのチュートリアルを目的としているため、Keras API と `tff.simulation` をバイパスし、生のモデルのコードを記述して、ゼロからフェデレーテッドデータセットを構築することにします。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b6qCjef350c_"
      },
      "source": [
        "### フェデレーテッドデータセットを準備する\n",
        "\n",
        "実演の目的により、10 人のユーザーから得たデータを使用し、各ユーザーが異なる数字を識別する方法に関する知識を貢献するというシナリオをシミュレーションすることにします。これは、ほぼ [i.i.d](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables) です。\n",
        "\n",
        "最初に、標準 MNIST データを読み込みましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uThZM4Ds-KDQ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
            "11493376/11490434 [==============================] - 0s 0us/step\n",
            "11501568/11490434 [==============================] - 0s 0us/step\n"
          ]
        }
      ],
      "source": [
        "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PkJc5rHA2no_"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]"
            ]
          },
          "execution_count": 6,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "[(x.dtype, x.shape) for x in mnist_train]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mFET4BKJFbkP"
      },
      "source": [
        "データはNumpy 配列として提供されます。1 つは画像を含み、もう 1 つは数字ラベルを含みます。どちらも最初の次元は個々の例に適用されます。フェデレーテッドシーケンスを TFF 計算にフィードする方法と互換性のある方法、つまりリストのリストとしてフォーマットするヘルパー関数を記述しましょう。各クライアントのシーケンスにおけうｒユーザーに及ぶ外部リスト（数値）とデータのバッチに及ぶ内部リストです。このようなバッチを、通例どおり、`x` と `y` という名前のついた、それぞれに主要なバッチ次元のあるテンソルのペアとして構成します。これを行いながら、各画像を 784 要素のベクトルに平板化し、そのピクセルを `0..1` の範囲にスケーリングし直して、データ変換によってモデルのロジックが混乱しないようにします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XTaTLiq5GNqy"
      },
      "outputs": [],
      "source": [
        "NUM_EXAMPLES_PER_USER = 1000\n",
        "BATCH_SIZE = 100\n",
        "\n",
        "\n",
        "def get_data_for_digit(source, digit):\n",
        "  output_sequence = []\n",
        "  all_samples = [i for i, d in enumerate(source[1]) if d == digit]\n",
        "  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):\n",
        "    batch_samples = all_samples[i:i + BATCH_SIZE]\n",
        "    output_sequence.append({\n",
        "        'x':\n",
        "            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],\n",
        "                     dtype=np.float32),\n",
        "        'y':\n",
        "            np.array([source[1][i] for i in batch_samples], dtype=np.int32)\n",
        "    })\n",
        "  return output_sequence\n",
        "\n",
        "\n",
        "federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]\n",
        "\n",
        "federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xpNdBimWaMHD"
      },
      "source": [
        "サニティーチェックをさっと行うために、5 番目のクライアント（数値 `5` に対応するクライアント）が貢献したデータの最後のバッチにある `Y` テンソルを確認しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bTNuL1W4bcuc"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)"
            ]
          },
          "execution_count": 8,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "federated_train_data[5][-1]['y']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xgvcwv7Obhat"
      },
      "source": [
        "確認のため、そのバッチの最後の要素に対応する画像も見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cI4aat1za525"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEBZJREFUeJzt3W1sU2Ufx/HfHgJKAOOQzo6hY9mE\nyTbINrMYZQjIgyaOh4Vn4mCEJeALA6hB34AJgSVCggmoVIgsMWogymZAFoiAyoxZivQFSAxOFsco\nG8gMAxU2PPcLc++WW3q1dH2C6/tJTrL239Pzz9l+O+25TnslOY7jCIB1kuPdAID4IPyApQg/YCnC\nD1iK8AOWIvyApQg/YCnCD1iK8AOWSo3lxpKSkmK5OcBKoV6026cjf0NDg0aOHKmcnBzV1NT05akA\nxJoTpp6eHic7O9tpbm52rl+/7hQWFjqnTp0yriOJhYUlykuowj7yNzU1KScnR9nZ2erXr5/mzZun\n+vr6cJ8OQIyFHf62tjYNHz6893ZmZqba2tr+9TiPx6OSkhKVlJSEuykAURD2Cb/bnVS43Qm96upq\nVVdXB6wDiI+wj/yZmZlqbW3tvX3u3DllZGREpCkAMRDuCb/u7m5nxIgRzs8//9x7wu/kyZOc8GNh\nifMSqrBf9qempmrr1q2aOnWqbt68qaqqKo0ePTrcpwMQY0nO7d68R2tjvOcHoi7USHN5L2Apwg9Y\nivADliL8gKUIP2Apwg9YivADliL8gKUIP2Apwg9YivADliL8gKUIP2Apwg9YivADliL8gKUIP2Ap\nwg9YivADliL8gKUIP2CpmE7RjcTz0ksvGevXrl0z1nft2hXBbm6VlZVlrCcnm49dc+bMCVgbNmyY\ncd0VK1YY688++6yxfuTIEWM9EXDkByxF+AFLEX7AUoQfsBThByxF+AFLEX7AUn0a58/KytKgQYOU\nkpKi1NRUeb3eSPWFGJkxY4axPmHCBGN9yJAhxrrP5wtYW7BggXHdRYsWGespKSnGel9cvXrVWO/s\n7IzatmOlzxf5HDlyRA899FAkegEQQ7zsByzVp/AnJSVpypQpKi4ulsfjiVRPAGKgTy/7GxsblZGR\noY6ODk2ePFmjRo1SWVnZLY/xeDz8YwASUJ+O/BkZGZIkl8ulmTNnqqmp6V+Pqa6ultfr5WQgkGDC\nDv+1a9fU1dXV+/PBgweVn58fscYARFfYL/vb29s1c+ZMSVJPT48WLFigadOmRawxANGV5DiOE7ON\nJSXFalMI0aFDh4z1YOP8wX6nMfzzuiMrV6401g8cOGCs//TTT5FsJ6JC3ecM9QGWIvyApQg/YCnC\nD1iK8AOWIvyApfjq7nuAabjtqaeeMq47fvz4SLcTsj/++MNY/+9FZIE0NDQY6+vXrw9YO3v2rHHd\nRB2ijCSO/IClCD9gKcIPWIrwA5Yi/IClCD9gKcIPWIqP9N4DBg4cGLD222+/RXXbN27cMNY///zz\ngLVNmzYZ1+Xbn8LDR3oBGBF+wFKEH7AU4QcsRfgBSxF+wFKEH7AUn+e/B8yePTtu216xYoWxvmvX\nrtg0gjvGkR+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsFHeevqqrSvn375HK5dPLkSUnS5cuXNXfu\nXLW0tCgrK0u7d+/Wgw8+GPVmbTVnzhxjfcuWLVHb9jvvvGOsM45/9wp65F+8ePG/JkeoqanRpEmT\ndObMGU2aNEk1NTVRaxBAdAQNf1lZmdLS0m65r76+XpWVlZKkyspK1dXVRac7AFET1nv+9vZ2ud1u\nSZLb7VZHR0dEmwIQfVG/tt/j8cjj8UR7MwDuUFhH/vT0dPn9fkmS3++Xy+UK+Njq6mp5vV6+jBFI\nMGGFv7y8XLW1tZKk2tpaTZ8+PaJNAYi+oOGfP3++nnzySf3444/KzMzUzp07tWbNGh06dEi5ubk6\ndOiQ1qxZE4teAUQQ39ufAAYMGGCsf/vtt8Z6fn5+2Ns+fPiwsV5RUWGsd3V1hb1tRAff2w/AiPAD\nliL8gKUIP2Apwg9YivADluKru2Ogf//+xvr27duN9b4M5QWzceNGY52hvHsXR37AUoQfsBThByxF\n+AFLEX7AUoQfsBThByzFOH8MPPPMM8b6/PnzY9PIbcyaNctYLywsNNavXLlirH/wwQd33BNigyM/\nYCnCD1iK8AOWIvyApQg/YCnCD1iK8AOW4qu7Y2D//v3G+rRp02LUSeQlJ5uPH/X19QFrwfbLzp07\njfW//vrLWLcVX90NwIjwA5Yi/IClCD9gKcIPWIrwA5Yi/IClgo7zV1VVad++fXK5XDp58qQkad26\ndXr//fc1dOhQSdKGDRv0/PPPB9+YpeP8RUVFxvq7775rrBcXF4e97dOnTxvrfr/fWB8+fLix/thj\njxnrfbmMZM2aNcb6pk2bwn7ue1nExvkXL16shoaGf92/cuVK+Xw++Xy+kIIPILEEDX9ZWZnS0tJi\n0QuAGAr7Pf/WrVtVWFioqqoqdXZ2RrInADEQVviXL1+u5uZm+Xw+ud1urV69OuBjPR6PSkpKVFJS\nEnaTACIvrPCnp6crJSVFycnJWrZsmZqamgI+trq6Wl6vV16vN+wmAUReWOH/5xnivXv3RnUWWQDR\nEfSru+fPn6+jR4/q0qVLyszM1JtvvqmjR4/K5/MpKSlJWVlZQaeYBpB4+Dx/AhgwYICxnp2dHfZz\nt7W1GevBTtYOGTLEWB85cqSx/vrrrwesPffcc8Z1b968aazPmDHDWD9w4ICxfq/i8/wAjAg/YCnC\nD1iK8AOWIvyApQg/YCmG+iLg/vvvN9b//PNPYz2Gv4KYS0lJCVjz+XzGdfPy8oz1xsZGY338+PHG\n+r2KoT4ARoQfsBThByxF+AFLEX7AUoQfsBThBywV9PP8+NsDDzwQsPbRRx8Z1509e7ax/vvvv4fV\n091g4MCBAWv33Xdfn547NZU/377gyA9YivADliL8gKUIP2Apwg9YivADliL8gKUYKA2RabqxqVOn\nGtcNNo11sM+1JzLTOL4kffjhhwFrI0aMiHQ7uAMc+QFLEX7AUoQfsBThByxF+AFLEX7AUoQfsFTQ\ncf7W1la9+OKLunDhgpKTk1VdXa2XX35Zly9f1ty5c9XS0qKsrCzt3r1bDz74YCx6vus0NDQY66Zp\nrCVpz549kWznjixevNhYX7t2rbHel7+J7u5uY/29994L+7kRwpE/NTVVmzdv1unTp/Xdd99p27Zt\n+uGHH1RTU6NJkybpzJkzmjRpkmpqamLRL4AICRp+t9utoqIiSdKgQYOUl5entrY21dfXq7KyUpJU\nWVmpurq66HYKIKLu6D1/S0uLTpw4odLSUrW3t8vtdkv6+x9ER0dHVBoEEB0hX9t/9epVVVRUaMuW\nLRo8eHDIG/B4PPJ4PGE1ByB6Qjryd3d3q6KiQgsXLtSsWbMkSenp6fL7/ZIkv98vl8t123Wrq6vl\n9Xrl9Xoj1DKASAgafsdxtHTpUuXl5WnVqlW995eXl6u2tlaSVFtbq+nTp0evSwARF3SK7mPHjmnc\nuHEqKChQcvLf/ys2bNig0tJSzZkzR7/88oseeeQR7dmzR2lpaeaN3cVTdJeWlgasHT582Lhu//79\nI91Owgj2OzX9eXV2dhrXDTYEumPHDmPdVqFO0R30Pf/TTz8d8Mm+/PLLO+sKQMLgCj/AUoQfsBTh\nByxF+AFLEX7AUoQfsFTQcf6IbuwuHuc3WbJkibEe7KOnKSkpkWwnpoL9Ti9evBiwVlFRYVy3sbEx\nrJ5sF2qkOfIDliL8gKUIP2Apwg9YivADliL8gKUIP2ApxvljYNSoUcb6Z599ZqwHm+I7moJNH75v\n3z5j3XSNw4ULF8LqCWaM8wMwIvyApQg/YCnCD1iK8AOWIvyApQg/YCnG+YF7DOP8AIwIP2Apwg9Y\nivADliL8gKUIP2Apwg9YKmj4W1tbNWHCBOXl5Wn06NF6++23JUnr1q3TsGHDNHbsWI0dO1ZffPFF\n1JsFEDlBL/Lx+/3y+/0qKipSV1eXiouLVVdXp927d2vgwIF65ZVXQt8YF/kAURfqRT6pwR7gdrvl\ndrslSYMGDVJeXp7a2tr61h2AuLuj9/wtLS06ceKESktLJUlbt25VYWGhqqqq1NnZedt1PB6PSkpK\nVFJS0vduAUSOE6Kuri6nqKjI+fTTTx3HcZwLFy44PT09zs2bN5033njDWbJkSdDnkMTCwhLlJVQh\nPfLGjRvOlClTnM2bN9+2fvbsWWf06NGEn4UlAZZQBX3Z7ziOli5dqry8PK1atar3fr/f3/vz3r17\nlZ+fH+ypACSQoGf7jx07pnHjxqmgoEDJyX//r9iwYYM+/vhj+Xw+JSUlKSsrS9u3b+89MRhwY5zt\nB6IuSKR78Xl+4B4TaqS5wg+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsR\nfsBShB+wFOEHLBX0CzwjaciQIcrKyuq9ffHiRQ0dOjSWLYQsUXtL1L4kegtXJHtraWkJ+bEx/Tz/\n/yspKZHX643X5o0StbdE7Uuit3DFqzde9gOWIvyApVLWrVu3Lp4NFBcXx3PzRonaW6L2JdFbuOLR\nW1zf8wOIH172A5aKS/gbGho0cuRI5eTkqKamJh4tBJSVlaWCggKNHTs27lOMVVVVyeVy3TInwuXL\nlzV58mTl5uZq8uTJAadJi0dviTJzc6CZpeO97xJuxuuQp/eIkJ6eHic7O9tpbm52rl+/7hQWFjqn\nTp2KdRsBPfroo87Fixfj3YbjOI7z1VdfOcePH79lNqRXX33V2bhxo+M4jrNx40bntddeS5je1q5d\n67z11ltx6eefzp8/7xw/ftxxHMe5cuWKk5ub65w6dSru+y5QX/HabzE/8jc1NSknJ0fZ2dnq16+f\n5s2bp/r6+li3cVcoKytTWlraLffV19ersrJSklRZWam6urp4tHbb3hKF2+1WUVGRpFtnlo73vgvU\nV7zEPPxtbW0aPnx47+3MzMyEmvI7KSlJU6ZMUXFxsTweT7zb+Zf29vbemZHcbrc6Ojri3NGtQpm5\nOZb+ObN0Iu27cGa8jrSYh9+5zeBCIs3k09jYqO+//14HDhzQtm3b9PXXX8e7pbvG8uXL1dzcLJ/P\nJ7fbrdWrV8e1n6tXr6qiokJbtmzR4MGD49rLP/1/X/HabzEPf2ZmplpbW3tvnzt3ThkZGbFuI6D/\n9uJyuTRz5kw1NTXFuaNbpaen906S6vf75XK54tzR/6SnpyslJUXJyclatmxZXPddd3e3KioqtHDh\nQs2aNau3v3jvu0B9xWO/xTz8TzzxhM6cOaOzZ8/qxo0b+uSTT1ReXh7rNm7r2rVr6urq6v354MGD\nCTf7cHl5uWprayVJtbW1mj59epw7+p9EmbnZCTCzdLz3XaC+4rbfYn6K0XGc/fv3O7m5uU52draz\nfv36eLRwW83NzU5hYaFTWFjoPP7443Hvbd68ec7DDz/spKamOsOGDXN27NjhXLp0yZk4caKTk5Pj\nTJw40fn1118TprdFixY5+fn5TkFBgfPCCy8458+fj0tv33zzjSPJKSgocMaMGeOMGTPG2b9/f9z3\nXaC+4rXfuMIPsBRX+AGWIvyApQg/YCnCD1iK8AOWIvyApQg/YCnCD1jqP1hNIrYb+rn+AAAAAElF\nTkSuQmCC\n",
            "text/plain": [
              "<Figure size 600x400 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "from matplotlib import pyplot as plt\n",
        "\n",
        "plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J-ox58PA56f8"
      },
      "source": [
        "### TensorFlow と TFF の組み合わせについて\n",
        "\n",
        "このチュートリアルでは、読みやすくまとめるために、TensorFlow ロジックを導入する関数をすぐに `tff.tf_computation` でデコレートしていますが、より複雑なロジックの場合にお勧めするパターンではありません。TensorFlow のデバッグはすでに困難であり、完全にシリアル化された上で再インポートされた後の TensorFlow のデバッグは、必然的にメタデータを失ってインタラクティビティを制限するため、デバッグがさらに困難となります。\n",
        "\n",
        "そのため、**複雑な TF ロジックはスタンドアロンの Python 関数として記述することを強くお勧めします**（`tff.tf_computation` デコレーションを使用せずに、ということです）。こうすることで、TFF への計算をシリアル化する前に TF のベストプラクティスとツール（Eager モードなど）を使用して（Python 関数を引数として`tff.tf_computation` を呼び出すことで）、TensorFlow ロジックを開発およびテストすることができます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RSd6UatXbzw-"
      },
      "source": [
        "### 損失関数を定義する\n",
        "\n",
        "データの用意ができたため、トレーニングに使用する損失関数を定義しましょう。まず、入力のタイプを TFF の名前付きタプルとして定義思案す。データバッチのサイズは多様であるため、バッチの次元を `None` に設定して、弧の次元のサイズが不明であることを指定します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "653xv5NXd4fy"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'<x=float32[?,784],y=int32[?]>'"
            ]
          },
          "execution_count": 10,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "BATCH_SPEC = collections.OrderedDict(\n",
        "    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),\n",
        "    y=tf.TensorSpec(shape=[None], dtype=tf.int32))\n",
        "BATCH_TYPE = tff.to_type(BATCH_SPEC)\n",
        "\n",
        "str(BATCH_TYPE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pb6qPUvyh5A1"
      },
      "source": [
        "なぜ通常の Python タイプを定義できないのかと思っているかもしれませんが、[第 1 部](custom_federated_algorithms_1.ipynb)において、Python を使って TFF 計算のロジックを表現できたとしても、内部的には TFF 計算は Python として*扱われない*と説明したことを思い出してください。上記で定義されたシンボル `BATCH_TYPE` は、抽象的な TFF タイプの仕様を表します。この*抽象的な* TFF タイプと具象的な Python *式*タイプ（Python 関数の本文で TFF タイプを表現するために使用される `dict` や `collections.namedtuple` といったコンテナなど）を区別することが重要です。Python とは異なり、TFF には、タプル式コンテナ用に単一の抽象型コンストラクタ `tff.StructType` があり、個別に名前を付けたり無名のままにできる要素が伴います。TFF 計算は仮に 1 つのパラメータと 1 つの結果のみを定義できるため、このタイプは計算の仮パラメータもモデルするために使用できます。この例については以下の方で紹介されています。\n",
        "\n",
        "では、TFF タイプのモデルパラメータを定義しましょう。もう一度、TFF の*重み*と*バイアス*の名前付きタプルとして定義します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Og7VViafh-30"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<weights=float32[784,10],bias=float32[10]>\n"
          ]
        }
      ],
      "source": [
        "MODEL_SPEC = collections.OrderedDict(\n",
        "    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),\n",
        "    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))\n",
        "MODEL_TYPE = tff.to_type(MODEL_SPEC)\n",
        "\n",
        "print(MODEL_TYPE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iHhdaWSpfQxo"
      },
      "source": [
        "上記の定義が完了したところで、シングルバッチにおける特定のモデルの損失を定義することができます。`@tff.tf_computation` 内に `@tf.function` デコレータが使用されるところに注意してください。こうすることで、`tff.tf_computation` デコレータが作成した `tf.Graph` のコンテキスト内であっても、Python 式のセマンティクスを使用して TF を記述することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4EObiz_Ke0uK"
      },
      "outputs": [],
      "source": [
        "# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can \n",
        "# be later called from within another tf.function. Necessary because a\n",
        "# @tf.function  decorated method cannot invoke a @tff.tf_computation.\n",
        "\n",
        "@tf.function\n",
        "def forward_pass(model, batch):\n",
        "  predicted_y = tf.nn.softmax(\n",
        "      tf.matmul(batch['x'], model['weights']) + model['bias'])\n",
        "  return -tf.reduce_mean(\n",
        "      tf.reduce_sum(\n",
        "          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))\n",
        "\n",
        "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)\n",
        "def batch_loss(model, batch):\n",
        "  return forward_pass(model, batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8K0UZHGnr8SB"
      },
      "source": [
        "期待した通り、`batch_loss` の計算では、モデルと単一のデータバッチに `float32` 型の損失が返されます。`MODEL_TYPE` と `BATCH_TYPE` が 2 タプルの仮パラメータにひとまとめにされており、`batch_loss` の型を `(<MODEL_TYPE,BATCH_TYPE> -> float32)` として認識することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4WXEAY8Nr89V"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'"
            ]
          },
          "execution_count": 13,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(batch_loss.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pAnt_UcdnvGa"
      },
      "source": [
        "サニティーチェックとして、ゼロが埋められた初期モデルを構築し、上記で視覚化したデータのバッチでの損失を計算しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U8Ne8igan3os"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "2.3025854"
            ]
          },
          "execution_count": 14,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "initial_model = collections.OrderedDict(\n",
        "    weights=np.zeros([784, 10], dtype=np.float32),\n",
        "    bias=np.zeros([10], dtype=np.float32))\n",
        "\n",
        "sample_batch = federated_train_data[5][-1]\n",
        "\n",
        "batch_loss(initial_model, sample_batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ckigEAyDAWFV"
      },
      "source": [
        "`dict` として定義された初期モデルで TFF 計算をフィードします。これは、それを定義する Python 関数の本文が、`model['weight']` と `model['bias']` としてモデルパラメータを消費する場合でもです。`batch_loss` への呼び出しの引数は、単純にその関数の本文に渡されません。\n",
        "\n",
        "`batch_loss` を呼び出す場合に何が起きるでしょうか。`batch_loss` の Python 本文は、それが定義された上記のセルでトレースされ、シリアル化されています。TFF は、計算の定義時に `batch_loss` の呼び出し元として機能し、`batch_loss` の呼び出し時に呼び出しのターゲットとして機能します。両方の役割において、TFF は TFF の象徴タイプシステムと Python 式表現タイプの橋渡しとして動作します。呼び出し時には、TFF はほとんどの標準 Python コンテナタイプ（`dict`、`list`、`tuple`、`collections.namedtuple` など）を抽象的な TFF タプルの具象表現として受け入れます。また、上記で述べたように、TFF 計算は仮に単一のパラメータのみを受け入れるため、パラメータの型がタプルである場合に位置やキーワード引数を使った使い慣れた Python 呼び出しを使用することができ、期待通りに動作します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eB510nILYbId"
      },
      "source": [
        "### シングルバッチの勾配降下\n",
        "\n",
        "では、この損失関数を使用して単一ステップの勾配降下を実行する計算を定義しましょう。この関数を定義するにおいて、`batch_loss` をサブコンポーネントとして使用することに注意してください。別の計算の本文内に `tff.tf_computation` で構築された計算を呼び出すことはできますが、通常そうする必要はありません。上記で述べたように、シリアル化によってデバッグ情報の一部が損なわれるため、より複雑な計算で、`tff.tf_computation` デコレータを使用せずにすべての TensorFlow を記述してテストすることが好ましいといえます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O4uaVxw3AyYS"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)\n",
        "def batch_train(initial_model, batch, learning_rate):\n",
        "  # Define a group of model variables and set them to `initial_model`. Must\n",
        "  # be defined outside the @tf.function.\n",
        "  model_vars = collections.OrderedDict([\n",
        "      (name, tf.Variable(name=name, initial_value=value))\n",
        "      for name, value in initial_model.items()\n",
        "  ])\n",
        "  optimizer = tf.keras.optimizers.SGD(learning_rate)\n",
        "\n",
        "  @tf.function\n",
        "  def _train_on_batch(model_vars, batch):\n",
        "    # Perform one step of gradient descent using loss from `batch_loss`.\n",
        "    with tf.GradientTape() as tape:\n",
        "      loss = forward_pass(model_vars, batch)\n",
        "    grads = tape.gradient(loss, model_vars)\n",
        "    optimizer.apply_gradients(\n",
        "        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))\n",
        "    return model_vars\n",
        "\n",
        "  return _train_on_batch(model_vars, batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y84gQsaohC38"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'"
            ]
          },
          "execution_count": 16,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(batch_train.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ID8xg9FCUL2A"
      },
      "source": [
        "こういった別の関数の本文内で `tff.tf_computation` でデコレートされた Python 関数を呼び出す場合、内部の TFF 計算のロジックは外部の TFF 計算のロジックに埋め込まれます（基本的にインライン）。前述のとおり、両方の計算を記述している場合は、内部関数（この場合 `batch_loss`）を、`tff.tf_computation` ではなく、通常の Python または `tf.function` にするのが好ましいといえますが、ここでは別の関数内で 1 つの `tff.tf_computation` を呼び出しても期待通りに動作します。これは、`batch_loss` を定義している Python コードがないが、シリアル化された TFF 表現しかない場合に、必要となります。\n",
        "\n",
        "では、この関数を初期モデルに何度か適用して、損失が縮減するかどうかを見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8edcJTlXUULm"
      },
      "outputs": [],
      "source": [
        "model = initial_model\n",
        "losses = []\n",
        "for _ in range(5):\n",
        "  model = batch_train(model, sample_batch, 0.1)\n",
        "  losses.append(batch_loss(model, sample_batch))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3n1onojT1zHv"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]"
            ]
          },
          "execution_count": 18,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EQk4Ha8PU-3P"
      },
      "source": [
        "### 一連のローカルデータの勾配降下\n",
        "\n",
        "`batch_train` が動作しているようなので、あるユーザーのｎシングルバッチからだけでなく、すべてのバッチの全シーケンスを消費する `local_train` という類似したトレーニング関数を記述しましょう。新しい計算では、`BATCH_TYPE` ではなく `tff.SequenceType(BATCH_TYPE)` を消費する必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EfPD5a6QVNXM"
      },
      "outputs": [],
      "source": [
        "LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)\n",
        "\n",
        "@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)\n",
        "def local_train(initial_model, learning_rate, all_batches):\n",
        "\n",
        "  # Mapping function to apply to each batch.\n",
        "  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)\n",
        "  def batch_fn(model, batch):\n",
        "    return batch_train(model, batch, learning_rate)\n",
        "\n",
        "  return tff.sequence_reduce(all_batches, initial_model, batch_fn)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sAhkS5yKUgjC"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'"
            ]
          },
          "execution_count": 20,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(local_train.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EYT-SiopYBtH"
      },
      "source": [
        "この短いコードセクションには、たくさんの詳細が含められているため、1 つずつ確認することにします。\n",
        "\n",
        "まず、先に行ったように、このロジックを TensorFlow だけで実装し、`tf.data.Dataset.reduce` に頼ってシーケンスを処理することも可能ですが、今回は、`tff.federated_computation` として、グルー言語でロジックを表現することを省略しています。還元を実行するために、フェデレーテッド演算子 `tff.sequence_reduce` を使用しました。\n",
        "\n",
        "演算子 `tff.sequence_reduce` は `tf.data.Dataset.reduce` と同じように使用されています。基本的には `tf.data.Dataset.reduce` と同じように考えることができますが、フェデレーテッド計算内で使用する場合、覚えていらっしゃるかもしれませんが、TensorFlow コードを含めることができません。これは、`T` タイプの要素の*シーケンス*、あるタイプ `U` の還元の初期状態（これを抽象的に*ゼロ*と呼びます）、単一要素を処理することで還元の状態をアラートするタイプ `(<U,T> -> U)` の*還元演算子*で較正される仮の 3 組タプルのパラメータを使用する、テンプレート演算子です。すべての要素を順に処理した後の還元の最終状態が結果となります。この例では、還元の状態は、データの接頭辞でトレーニングされたモデルで、要素はデータのバッチです。\n",
        "\n",
        "次に、再度、1 つの計算を（`batch_train`）を別の計算（`local_train`）内のコンポーネントとして使用していますが、直接的でないところに注意してください。これは学習率という別のパラメータを追加で使用するため、還元演算子としては使用することができないためです。これを解決するために、本文で `local_train` のパラメータ `learning_rate` にバインドする埋め込みフェデレーテッド計算 `batch_fn` を定義します。子計算が親計算の本文の外で呼び出されない限り、このように定義された子計算が親計算の仮のパラメータをキャプチャすることができます。このパターンは、Python の `functools.partial` に相当すると捉えることができます。\n",
        "\n",
        "この方法で `learning_rate` を取得した場合には、もちろん実際的に、すべてのバッチで同じ学習率の値が使用されることになります。\n",
        "\n",
        "では、新たに定義したローカルトレーニング関数を、サンプルバッチ（数値 `5`）を貢献した同一ユーザーのデータの全シーケンスで試してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EnWFLoZGcSby"
      },
      "outputs": [],
      "source": [
        "locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y0UXUqGk9zoF"
      },
      "source": [
        "うまくいきましたか。この問いに答えるには、評価を実装する必要があります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a8WDKu6WYy__"
      },
      "source": [
        "### ローカル評価\n",
        "\n",
        "次は、すべてのデータバッチの損失を加算してローカル評価を実装した一例です（平均を計算することもできましたが、読み手の演習として、取っておくことにします）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0RiODuc6z7Ln"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)\n",
        "def local_eval(model, all_batches):\n",
        "  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.\n",
        "  return tff.sequence_sum(\n",
        "      tff.sequence_map(\n",
        "          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),\n",
        "          all_batches))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pH2XPEAKa4Dg"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'"
            ]
          },
          "execution_count": 23,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(local_eval.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "efX81HuE-BcO"
      },
      "source": [
        "また、このコードで説明される新しい要素がいくつかあるので、1 つずつ確認しましょう。\n",
        "\n",
        "まず、シーケンスの処理に、2 つの新しいフェデレーテッド演算子を使用しています。*マッピング関数* <code>T->U</code> と `T` の<em>sequence</em>を取って、マッピング関数をポイントごとに適用して取得される `U` のシーケンスを放出する <code>tff.sequence_map</code> と、すべての要素をただ加算する <code>tff.sequence_sum</code> です。ここでは、各データバッチを損失値にマッピングし、結果の損失値を加算して合計損失を計算します。\n",
        "\n",
        "`tff.sequence_reduce` をもう一度使用することもできましたが、これは最善の選択肢ではありません。還元プロセスは、その定義上シーケンシャルですが、マッピングと加算は並行して計算されるからです。選択肢がある場合は、実装の選択肢に制約を与えない演算子を使用することが最善であり、TFF 計算が将来的に特定の環境にデプロイされるようにコンパイルした場合に、より高速でスケーラブルな、リソース効率に優れた実行を行うためのあらゆる潜在的な機会をまるごと活用することができます。\n",
        "\n",
        "次に、`local_train` と同じように、必要なコンポーネント関数（`batch_loss`）はフェデレーテッド演算子（`tff.sequence_map`）が期待する数より多いパラメータを取るため、もう一度部分的な定義を行っています。今回は、`lambda` を `tff.federated_computation` として直接タッピングしてインライン化しています。`tff.tf_computation` を使用して TFF に TensorFlow ロジックを埋め込むには、関数を引数として、ラッパーをインラインで使用することが推奨されます。\n",
        "\n",
        "では、トレーニングが機能するかどうかを確認しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vPw6JSVf5q_x"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model loss = 23.025854\n",
            "locally_trained_model loss = 0.4348469\n"
          ]
        }
      ],
      "source": [
        "print('initial_model loss =', local_eval(initial_model,\n",
        "                                         federated_train_data[5]))\n",
        "print('locally_trained_model loss =',\n",
        "      local_eval(locally_trained_model, federated_train_data[5]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Tvu70cnBsUf"
      },
      "source": [
        "まさに、損失が減少しました。しかし、別のユーザーのデータで評価する場合はどうでしょうか。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gjF0NYAj5wls"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model loss = 23.025854\n",
            "locally_trained_model loss = 74.50075\n"
          ]
        }
      ],
      "source": [
        "print('initial_model loss =', local_eval(initial_model,\n",
        "                                         federated_train_data[0]))\n",
        "print('locally_trained_model loss =',\n",
        "      local_eval(locally_trained_model, federated_train_data[0]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7WPumnRTBzUs"
      },
      "source": [
        "期待通り、悪化してしまいました。モデルは `5` を認識するようにトレーニングされており、`0` に遭遇したことがありません。そこから次のような疑問が生じます。グローバルな観点では、ローカルトレーニングはモデルの質にどのような影響を与えたのでしょうか。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QJnL2mQRZKTO"
      },
      "source": [
        "### フェデレーテッド評価\n",
        "\n",
        "巡り巡ってようやくフェデレーテッドタイプとフェデレーテッド計算にたどり着きました。開始点のトピックです。次は、サーバーをオリジンとするモデルとクライアントに残るデータの TFF タイプの定義です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LjGGhpoEBh_6"
      },
      "outputs": [],
      "source": [
        "SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)\n",
        "CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4gTXV2-jZtE3"
      },
      "source": [
        "すべての定義の導入が完了しているため、TFF でフェデレーテッド評価を表現するだけとなりました。モデルをクライアントに配布し、各クライアントにローカル部分のデータでのローカル評価を呼び出させ、損失の平均を割り出します。次は、これを記述した一例です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2zChEPzEBx4T"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)\n",
        "def federated_eval(model, data):\n",
        "  return tff.federated_mean(\n",
        "      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IWcNONNWaE0N"
      },
      "source": [
        "`tff.federated_mean` と `tff.federated_map` については、より単純なシナリオでの例をすでに見てきました。直感的なレベルでは期待通りに動作しますが、このセクションのコードには、目に見える以上の詳細があるため、注意深く確認することにしましょう。\n",
        "\n",
        "まず、「*各クライアントにローカル部分のデータでのローカル評価を呼び出させ*」という部分を解析してみましょう。前のセクションで述べたように、`local_eval` には `(<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32)` という形式のタイプシグネチャがあります。\n",
        "\n",
        "フェデレーテッド演算子 `tff.federated_map` は、パラメータとして、あるタイプの `T->U` の*マッピング関数*とタイプ `{T}@CLIENTS` のフェデレーテッド値（マッピング関数のパラメータと同じタイプのメンバー構成要素）で構成される2 組タプルを受け入れ、タイプ `{U}@CLIENTS` の結果を返すテンプレートです。\n",
        "\n",
        "`local_eval` をクライアントごとに適用するマッピング関数としてフィードしているため、2 つ目の引数はフェデレーテッドタイプ `{<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS` である必要があります。前のセクションの命名法を使えば、フェデレーテッドタプルです。各クライアントは、メンバー構成要素として`local_eval` の全引数を格納しなければなりませんが、代わりに 2 要素の Python `list` をフィードしてみます。どうなりますか。\n",
        "\n",
        "実際のところ、これは TFF の*暗黙の型変換*の例で、`float` を受け入れる関数に `int` をフィードした場合など、ほかの場所で見たことのある暗黙の型変換に似ています。暗黙の変換はこの時点ではほとんど使用されることはありませんが、TFF では、ボイラープレートを最小限に抑える手段として、さらに普及させることを予定しています。\n",
        "\n",
        "この場合に適用される暗黙の変換は、`{<X,Y>}@Z` という形式のフェデレーテッドタプルと、`<{X}@Z,{Y}@Z>` というフェデレーテッド値の間の変換に相当します。正式には、これらの 2 つは別々のタイプシグネチャですが、プログラマーの観点から言えば、`Z` の各デバイスは、データ `X` と `Y` の 2 つのユニットを格納していることになります。ここで起きることは Python の `zip` と何ら変わりなく、このような変換を明示的に実行できる `tff.federated_zip` 演算子を提供しているため、`tff.federated_map` が 2 番目の引数としてタプルに遭遇すると、単にユーザーに代わって `tff.federated_zip` を起動することができます。\n",
        "\n",
        "上記を踏まえると、`tff.federated_broadcast(model)` という式を TFF タイプ `{MODEL_TYPE}@CLIENTS` の値を表現するもの、そして `data` を TFF タイプ `{LOCAL_DATA_TYPE}@CLIENTS`（または単に `CLIENT_DATA_TYPE`）の値として認識できるはずです。2つは、暗黙の `tff.federated_zip ` を介して一緒にフィルタリングされ、`tff.federated_map` の 2 番目の引数を形成します。\n",
        "\n",
        "演算子 `tff.federated_broadcast` は、ご察しのとおり、データをサーバーからクライアントに転送します。\n",
        "\n",
        "では、ローカルトレーニングによって、システムの平均損失にどのような影響が与えられたかを確認しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tbmtJItcn94j"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model loss = 23.025852\n",
            "locally_trained_model loss = 54.432625\n"
          ]
        }
      ],
      "source": [
        "print('initial_model loss =', federated_eval(initial_model,\n",
        "                                             federated_train_data))\n",
        "print('locally_trained_model loss =',\n",
        "      federated_eval(locally_trained_model, federated_train_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LQi2rGX_fK7i"
      },
      "source": [
        "まさに期待される通り、損失は増加しています。すべてのユーザーに対してこのモデルを改善するために、全員のデータでトレーニングする必要があります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkw9f59qfS7o"
      },
      "source": [
        "### フェデレーテッドトレーニング\n",
        "\n",
        "フェデレーテッドトレーニングの最も単純な実装方法は、ローカルでトレーニングしてからモデルを平均化する方法です。これは、次のコードで示される通り、すでに説明したのと同じビルディングブロックとパターンを使用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mBOC4uoG6dd-"
      },
      "outputs": [],
      "source": [
        "SERVER_FLOAT_TYPE = tff.FederatedType(tf.float32, tff.SERVER)\n",
        "\n",
        "\n",
        "@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,\n",
        "                           CLIENT_DATA_TYPE)\n",
        "def federated_train(model, learning_rate, data):\n",
        "  return tff.federated_mean(\n",
        "      tff.federated_map(local_train, [\n",
        "          tff.federated_broadcast(model),\n",
        "          tff.federated_broadcast(learning_rate), data\n",
        "      ]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z2vACMsQjzO1"
      },
      "source": [
        "モデルを平均化する代わりに、`tff.learning` が提供するフェデレーテッドアベレージングの全機能実装において、更新基準や圧縮などのさまざまな理由により、モデルのデルタを平均化したいと思います。\n",
        "\n",
        "トレーニングを数ラウンド実行し、実行前と実行後の平均損失を比較して、トレーニングが機能しているかどうかを確認しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NLx-3rLs9jGY"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "round 0, loss=21.60552406311035\n",
            "round 1, loss=20.365678787231445\n",
            "round 2, loss=19.27480125427246\n",
            "round 3, loss=18.31110954284668\n",
            "round 4, loss=17.45725440979004\n"
          ]
        }
      ],
      "source": [
        "model = initial_model\n",
        "learning_rate = 0.1\n",
        "for round_num in range(5):\n",
        "  model = federated_train(model, learning_rate, federated_train_data)\n",
        "  learning_rate = learning_rate * 0.9\n",
        "  loss = federated_eval(model, federated_train_data)\n",
        "  print('round {}, loss={}'.format(round_num, loss))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z0VjSLQzlUIp"
      },
      "source": [
        "完全性を期するために、テストデータでも実行して、モデルがうまく一般化されることを確認しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZaZT45yFMOaM"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model test loss = 22.795593\n",
            "trained_model test loss = 17.278767\n"
          ]
        }
      ],
      "source": [
        "print('initial_model test loss =',\n",
        "      federated_eval(initial_model, federated_test_data))\n",
        "print('trained_model test loss =', federated_eval(model, federated_test_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pxlHHwLGlgFB"
      },
      "source": [
        "これで、チュートリアルは完了です。\n",
        "\n",
        "もちろん、単純化されたこの例では、損失以外のメトリックを計算していないなど、より現実的なシナリオで実行する必要のある数々の内容が反映されていません。より完全な例として、また、私たちがお勧めするコーディング実践を実演するために、`tff.learning` でのフェデレーテッドアベレージングの[実装](https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/federated_averaging.py)を学習することをお勧めします。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_federated_algorithms_2.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
